Remove additional weight prefetching all gathers #3412
copybara-service[bot] merged 7 commits into main
This Pull Request introduces significant refactoring to the circular pipeline parallelism implementation in MaxText, specifically targeting weight prefetching efficiency and communication overlap. The transition to a custom_vjp for pipeline stage execution is a sophisticated improvement that allows for better control over memory and communication during both forward and backward passes.
🔍 General Feedback
- Significant Logic Refactoring: The transition to a sliding window of size 2 for block-sharded weights (`w_curr`, `w_next`) correctly addresses the pipeline delay and is much more efficient than fetching both current and next weights at every repeat.
- Custom VJP Implementation: The `custom_vjp` in `pipeline_utils.py` correctly handles the linear transposition of the prefetching logic, ensuring that gradients for the pipeline weights are accumulated properly through the repeats.
- Correctness Concern: A critical sharding mismatch was identified in `gather_microbatch_inputs_vmap` when `ShardMode.EXPLICIT` is used. This should be addressed before merging.
- Config Improvement: The new `pipeline-large-moe-cp.yml` configuration correctly adapts logical axes for large-scale MoE jobs, reflecting DeepSeek-style model structures.
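To make the sliding-window point concrete, here is a minimal, dependency-free sketch. It is not the MaxText implementation: `make_counting_gather`, `run_naive`, and `run_sliding_window` are hypothetical names, the "all-gather" is just a list copy that records which repeat it was asked for, and the modulo wraparound is an assumption standing in for the circular schedule. It only illustrates why carrying `w_curr` and `w_next` across repeats needs fewer gathers than fetching both at every repeat.

```python
from typing import Callable, List, Tuple

Weights = List[float]


def make_counting_gather(
    shards: List[Weights],
) -> Tuple[Callable[[int], Weights], List[int]]:
    """Return a stand-in for an all-gather plus a log of the repeats it fetched."""
    log: List[int] = []

    def gather(repeat_idx: int) -> Weights:
        log.append(repeat_idx)
        # Assumption: repeats wrap around the available weight blocks.
        return list(shards[repeat_idx % len(shards)])

    return gather, log


def run_naive(shards: List[Weights], num_repeats: int):
    """Fetch both the current and the next weights at every repeat."""
    gather, log = make_counting_gather(shards)
    used: List[Weights] = []
    for r in range(num_repeats):
        w_curr = gather(r)      # already fetched as "next" one repeat earlier
        w_next = gather(r + 1)  # prefetched, then thrown away at loop end
        used.append(w_curr)
    return used, log


def run_sliding_window(shards: List[Weights], num_repeats: int):
    """Keep a window of size 2 and fetch only the incoming weights."""
    gather, log = make_counting_gather(shards)
    used: List[Weights] = []
    w_curr = gather(0)          # one up-front gather fills the window
    for r in range(num_repeats):
        w_next = gather(r + 1)  # the only gather issued in this repeat
        used.append(w_curr)
        w_curr = w_next         # slide the window forward
    return used, log


if __name__ == "__main__":
    shards = [[1.0], [2.0], [3.0]]
    _, naive_log = run_naive(shards, num_repeats=4)
    _, window_log = run_sliding_window(shards, num_repeats=4)
    print(len(naive_log), len(window_log))
```

In this toy, both variants consume the same weights in the same order, but the naive loop issues two gathers per repeat while the sliding window issues one per repeat plus a single initial fill.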
gobbleturk left a comment
This is awesome!
Is there an existing test that protects the correctness of this? If not, we should definitely add one in this PR.
The correctness of the loss and gradients is protected by this test: `maxtext/tests/unit/pipeline_parallelism_test.py`, lines 283 to 299 at 60dbacd.
Description
This PR does the following two things: